import numpy as np
import torch
import gym
import argparse
import os
import d4rl
from tqdm import trange
from coolname import generate_slug
import time
import json
from log import Logger

import utils
from utils import VideoRecorder, get_expert_traj, merge_trajectories, get_dataset_return
from utils import load_trajectories_with_goals
import IQL
from scipy.spatial import KDTree
import sys
import datetime
from typing import Dict, Tuple, Optional
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import reward_design.utils

# Torch device for this script: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

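# The evaluation helpers below roll the policy out for a fixed number of episodes
# on an eval environment seeded with `seed + seed_offset` and return the
# D4RL-normalized score. The commented-out helpers that immediately follow are an
# unused KD-tree-based reward baseline.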

# def squashing_func(distance, action_dim, beta=0.5, scale=1.0, no_action_dim=False):
#     if no_action_dim:
#         squashed_value = scale * np.exp(-beta * distance)
#     else:
#         squashed_value = scale * np.exp(-beta * distance/action_dim)
    
#     return squashed_value
    
# def rewarder(kd_tree, key, num_k, action_dim, beta, scale, no_action_dim=False):
    
#     distance, _ = kd_tree.query(key, k=[num_k], workers=-1)
#     reward = squashing_func(distance, action_dim, beta, scale, no_action_dim)
#     return reward
        

def calculate_reward(obs, action, next_obs, local_dict, design_mode):
    """Evaluate the designed reward ``compute_dense_reward`` on one transition.

    ``design_mode`` selects which parts of the transition the reward sees:

    * ``"sa"``  -> ``compute_dense_reward(obs, action)``
    * ``"sas"`` -> ``compute_dense_reward(obs, action, next_obs)``
    * ``"ss"``  -> ``compute_dense_reward(obs, next_obs)``

    ``local_dict`` is the namespace holding ``compute_dense_reward`` (in this
    script, the dict the reward code was ``exec``'d into).

    Raises:
        NotImplementedError: if ``design_mode`` is not one of the modes above.
    """
    mode_inputs = (
        ("sa", (obs, action)),
        ("sas", (obs, action, next_obs)),
        ("ss", (obs, next_obs)),
    )
    for mode, reward_inputs in mode_inputs:
        if design_mode == mode:
            return local_dict["compute_dense_reward"](*reward_inputs)
    raise NotImplementedError
        

def eval_policy(args, iter, video: VideoRecorder, logger: Logger,
                policy, env_name, seed, mean, std, seed_offset=100,
                eval_episodes=10):
    """Roll out ``policy`` in a fresh ``env_name`` env and return its D4RL score.

    Before each ``policy.select_action`` call the raw observation is reshaped to
    ``(1, -1)`` and standardized as ``(obs - mean) / std``.

    Args:
        args: Parsed CLI arguments; currently unused (only referenced by the
            disabled video-recording lines).
        iter: Training step, used as the step index for ``logger.log``.
        video: Video recorder; currently unused (recording is commented out).
        logger: Receives the mean episode length, mean return and D4RL score.
        policy: Object exposing ``select_action(state)``.
        env_name: Gym id passed to ``gym.make``.
        seed: Base seed; the eval env is seeded with ``seed + seed_offset``.
        mean: Observation mean used for standardization.
        std: Observation std used for standardization.
        seed_offset: Keeps the eval seed distinct from ``seed``.
        eval_episodes: Number of episodes to roll out.

    Returns:
        ``eval_env.get_normalized_score(mean episode return) * 100``.
    """
    eval_env = gym.make(env_name)
    eval_env.seed(seed + seed_offset)

    lengths = []
    returns = []
    avg_reward = 0.
    # A new env is created on every evaluation; always close it so repeated
    # evaluations over a long training run do not leak env resources.
    try:
        for _ in range(eval_episodes):
            # video.init(enabled=(args.save_video and _ == 0))
            state, done = eval_env.reset(), False
            # video.record(eval_env)
            steps = 0
            episode_return = 0
            while not done:
                state = (np.array(state).reshape(1, -1) - mean)/std
                action = policy.select_action(state)
                state, reward, done, _ = eval_env.step(action)
                # video.record(eval_env)
                avg_reward += reward
                episode_return += reward
                steps += 1
            lengths.append(steps)
            returns.append(episode_return)
            # video.save(f'eval_s{iter}_r{str(episode_return)}.mp4')

        avg_reward /= eval_episodes
        d4rl_score = eval_env.get_normalized_score(avg_reward) * 100
    finally:
        eval_env.close()

    logger.log('eval/offline lengths_mean', np.mean(lengths), iter)
    logger.log('eval/offline returns_mean', np.mean(returns), iter)
    logger.log('eval/offline d4rl_score', d4rl_score, iter)

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {d4rl_score:.3f}")
    print("---------------------------------------")
    return d4rl_score


def eval_policy_for_cal(local_dict, design_mode, args, iter, video: VideoRecorder, logger: Logger,
                policy, env_name, seed, mean, std, seed_offset=100,
                eval_episodes=10):
    """Like ``eval_policy``, but also accumulates the designed reward's return.

    The policy acts on the standardized observation ``(obs - mean) / std``,
    while the designed reward (via ``calculate_reward`` with ``local_dict`` and
    ``design_mode``) is evaluated on the raw ``state`` / ``next_state`` returned
    by the environment.

    Args:
        local_dict: Namespace holding ``compute_dense_reward``.
        design_mode: ``"sa"``, ``"sas"`` or ``"ss"``; see ``calculate_reward``.
        args: Parsed CLI arguments; currently unused.
        iter: Training step, used as the step index for ``logger.log``.
        video: Video recorder; currently unused (recording is commented out).
        logger: Receives mean length, env return, designed return and D4RL score.
        policy: Object exposing ``select_action(state)``.
        env_name: Gym id passed to ``gym.make``.
        seed: Base seed; the eval env is seeded with ``seed + seed_offset``.
        mean: Observation mean used for standardization.
        std: Observation std used for standardization.
        seed_offset: Keeps the eval seed distinct from ``seed``.
        eval_episodes: Number of episodes to roll out.

    Returns:
        ``eval_env.get_normalized_score(mean env return) * 100`` (the designed
        return is only logged).
    """
    eval_env = gym.make(env_name)
    eval_env.seed(seed + seed_offset)

    lengths = []
    returns = []
    cal_returns = []
    avg_reward = 0.
    # A new env is created on every evaluation; always close it.
    try:
        for _ in range(eval_episodes):
            # video.init(enabled=(args.save_video and _ == 0))
            state, done = eval_env.reset(), False
            # video.record(eval_env)
            steps = 0
            episode_return = 0
            cal_episode_return = 0
            while not done:
                normalized_state = (np.array(state).reshape(1, -1) - mean)/std
                action = policy.select_action(normalized_state)
                next_state, reward, done, _ = eval_env.step(action)
                # The designed reward is scored on raw (unnormalized) observations.
                cal_reward = calculate_reward(
                    obs=state,
                    action=action,
                    next_obs=next_state,
                    local_dict=local_dict,
                    design_mode=design_mode
                )
                # video.record(eval_env)
                avg_reward += reward
                episode_return += reward
                cal_episode_return += cal_reward
                steps += 1
                state = next_state
            lengths.append(steps)
            returns.append(episode_return)
            cal_returns.append(cal_episode_return)
            # video.save(f'eval_s{iter}_r{str(episode_return)}.mp4')

        avg_reward /= eval_episodes
        d4rl_score = eval_env.get_normalized_score(avg_reward) * 100
    finally:
        eval_env.close()

    logger.log('eval/offline lengths_mean', np.mean(lengths), iter)
    logger.log('eval/offline returns_mean', np.mean(returns), iter)
    logger.log('eval/offline cal_returns_mean', np.mean(cal_returns), iter)
    logger.log('eval/offline d4rl_score', d4rl_score, iter)

    print()
    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {d4rl_score:.3f}")
    print("---------------------------------------")
    return d4rl_score

def normalize_func(dataset: "utils.ReplayBuffer"):
    """Min-max normalize ``dataset.reward`` in place into ``[0, 1]``.

    ``dataset.reward`` is replaced by ``(r - r.min()) / (r.max() - r.min())``.
    When every reward is identical the range is zero; instead of producing
    NaNs from ``0 / 0``, all rewards are set to zero.

    Args:
        dataset: Object with a numpy ``reward`` array (in this script, the
            ``utils.ReplayBuffer``).

    Returns:
        None. Callers in this script do not use a return value.
    """
    reward_min = dataset.reward.min()
    reward_max = dataset.reward.max()
    reward_range = reward_max - reward_min
    if reward_range == 0:
        # Constant reward: subtracting the min yields exact zeros without 0/0.
        dataset.reward = dataset.reward - reward_min
    else:
        dataset.reward = (dataset.reward - reward_min) / reward_range

    return None

if __name__ == "__main__":
    start_time = time.time()

    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--policy", default="IQL")               # Policy name
    parser.add_argument("--env", default="halfcheetah-medium-v2")        # OpenAI gym environment name
    parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
    # argparse applies `type` only to string defaults, so numeric defaults are
    # given as ints here (a bare 5e3 / 1e6 default would stay a float).
    parser.add_argument("--eval_freq", default=int(5e3), type=int)       # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=int(1e6), type=int)   # Max time steps to run environment
    parser.add_argument("--save_model", action="store_true", default=False)        # Save model and optimizer parameters
    parser.add_argument('--eval_episodes', default=10, type=int)
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument("--normalize", default=False, action='store_true')
    parser.add_argument("--normalize_reward", type=reward_design.utils.str2bool, default=True)
    parser.add_argument("--k", default=1, type=int)                 # how many nearest neighbors are needed
    parser.add_argument("--beta", default=0.5, type=float)                      # coefficient in distance
    # IQL
    parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
    parser.add_argument("--temperature", default=3.0, type=float)
    parser.add_argument("--expectile", default=0.7, type=float)
    parser.add_argument("--tau", default=0.005, type=float)
    parser.add_argument("--discount", default=0.99, type=float)     # Discount factor
    # Work dir
    parser.add_argument('--work_dir', default='tmp', type=str)
    parser.add_argument('--expl_noise', default=0.2, type=float)
    parser.add_argument("--no_action_dim", action="store_true", default=False)     # whether to involve action dimension
    parser.add_argument("--dropout_rate", default=None)
    parser.add_argument("--bias", default=1.0, type=float)

    # Reward design
    parser.add_argument(
        "--use_oracle", type=reward_design.utils.str2bool, default=False,
        help="Train on the dataset's original (oracle) reward instead of a designed reward."
    )
    parser.add_argument(
        "--reward_dir", type=str, help="Path to the reward function."
    )
    parser.add_argument(
        "--reward_type", type=str, choices=["best", "worst"], help="Type of the reward function."
    )
    parser.add_argument(
        "--reward_index", type=int, help="Index of the reward function."
    )
    parser.add_argument(
        "--current_time", type=str, help="Current time."
    )
    parser.add_argument(
        "--R_min", type=float, help="R_min."
    )
    parser.add_argument(
        "--R_max", type=float, help="R_max."
    )

    args = parser.parse_args()
    # The designed-reward path reads keys.json / buffer.json from --reward_dir and
    # indexes them with --reward_type / --reward_index; fail early with a clear
    # message instead of a TypeError/KeyError deep inside the script.
    if not args.use_oracle and (
        args.reward_dir is None or args.reward_type is None or args.reward_index is None
    ):
        parser.error("--reward_dir, --reward_type and --reward_index are required unless --use_oracle is set")
    args.cooldir = generate_slug(2)

    # Build work dir
    base_dir = 'runs'

    if args.use_oracle:
        args.work_dir = os.path.join("oracle", args.env)
    else:
        with open(os.path.join(args.reward_dir, "keys.json"), "r", encoding="utf-8") as f:
            keys_data = json.load(f)
        with open(os.path.join(args.reward_dir, "buffer.json"), "r", encoding="utf-8") as f:
            buffer_data = json.load(f)
        args.env = keys_data["env_name"]
        base_dir = os.path.join(args.reward_dir, base_dir)

    utils.make_dir(base_dir)
    base_dir = os.path.join(base_dir, args.work_dir)
    utils.make_dir(base_dir)
    args.work_dir = base_dir
    utils.make_dir(args.work_dir)

    # make directory
    ts = time.gmtime()
    ts = time.strftime("%m-%d-%H:%M", ts)
    if args.current_time is None:
        # --current_time has no default; fall back to this run's timestamp so the
        # os.path.join below does not fail on None.
        args.current_time = ts
    exp_name = 'bs' + str(args.batch_size)
    if args.policy == 'IQL':
        exp_name += '-t' + str(args.temperature) + '-e' + str(args.expectile)
    else:
        raise NotImplementedError
    args.work_dir = os.path.join(args.work_dir, f"normalize_reward={args.normalize_reward}")
    args.work_dir = os.path.join(args.work_dir, exp_name)
    args.work_dir = os.path.join(args.work_dir, args.current_time)
    args.work_dir = os.path.join(args.work_dir, f"type={args.reward_type},index={args.reward_index},seed={args.seed}")
    utils.make_dir(args.work_dir)
    os.makedirs(args.work_dir, exist_ok=True)

    args.model_dir = os.path.join(args.work_dir, 'model')
    utils.make_dir(args.model_dir)
    args.video_dir = os.path.join(args.work_dir, 'video')
    utils.make_dir(args.video_dir)

    file_logger = reward_design.utils.FileLogger(filename=os.path.join(args.work_dir, "rl_output.log"))
    sys.stdout = file_logger
    sys.stderr = file_logger

    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    print("---------------------------------------")
    print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    env = gym.make(args.env)

    # Set seeds
    env.seed(args.seed)
    env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
    print(
        "=" * 10
        + f"Load d4rl.qlearning_dataset"
        + "=" * 10
    )

    # NOTE: `dataset` is the plain dict returned by d4rl.qlearning_dataset; all
    # training data (including rewards) lives in `replay_buffer` from here on.
    dataset = d4rl.qlearning_dataset(env)
    replay_buffer.convert_D4RL(dataset)
    print(
        "=" * 10
        + f"Loading dataset with goals"
        + "=" * 10
    )
    dataset_with_goals = load_trajectories_with_goals(args.env, env, dataset)
    goals = None
    # Goal settings come from llm_args.json in --reward_dir; oracle runs may not
    # provide a reward dir, in which case no goals are attached.
    if "antmaze" in args.env and args.reward_dir is not None:
        with open(os.path.join(args.reward_dir, "llm_args.json"), "r", encoding="utf-8") as f:
            llm_args = json.load(f)
        if not llm_args["disable_goal"]:
            if llm_args["fix_goal"]:
                goals = reward_design.utils.get_antmaze_fix_goals(args.env, replay_buffer.state.shape[0])
                print("=" * 10 + f"Using fixed goals {goals.shape}" + "=" * 10)
                print("goals[:5] is:", goals[:5])
            else:
                goals = dataset_with_goals["goals"]
                print("=" * 10 + f"Using dynamic goals {goals.shape}" + "=" * 10)
                print("goals[:5] is:", goals[:5])
    replay_buffer.add_goals(goals)

    if args.normalize:
        mean, std = replay_buffer.normalize_states()
    else:
        mean, std = 0, 1

    # IQL receives no auxiliary expert data in this script.
    data = None

    if args.use_oracle:
        print("="*10 + "Use oracle" + "="*10)
    else:
        print("="*10 + "Use designed reward" + "="*10)
        design_mode = keys_data["design_mode"]
        reward_key = keys_data[args.reward_type][args.reward_index]
        # 1.load code from reward obj
        reward_code = buffer_data[reward_key]["code"]

        print("-"*10 + "Reward code is:" + "-"*10)
        print(reward_code)
        # SECURITY: exec runs arbitrary Python from buffer.json with a copy of this
        # module's globals. Only point --reward_dir at trusted reward code.
        local_dict = {**globals(), **{"Dict": Dict, "Tuple": Tuple, "Optional": Optional}}
        exec(reward_code, local_dict)

        # Relabel every transition with the designed reward.
        reward_list = []
        for obs, action, next_obs in zip(replay_buffer.state_all, replay_buffer.action, replay_buffer.next_state_all):
            r = calculate_reward(
                obs=obs,
                action=action,
                next_obs=next_obs,
                local_dict=local_dict,
                design_mode=design_mode
            )
            reward_list.append(r)
        replay_buffer.reward = np.array(reward_list).reshape(-1, 1)

    if args.normalize_reward:
        minmax_normalized = False
        if args.use_oracle:
            if 'antmaze' in args.env:
                replay_buffer.reward -= 1.0
            elif 'halfcheetah' in args.env or 'walker2d' in args.env or 'hopper' in args.env:
                normalize_func(replay_buffer)
                minmax_normalized = True
        else:
            normalize_func(replay_buffer)
            minmax_normalized = True

        # normalize_func maps rewards into [0, 1]; when a target range is given,
        # stretch the replay-buffer rewards to [R_min, R_max]. (This must act on
        # replay_buffer.reward: `dataset` is the raw d4rl dict and has no
        # `.reward` attribute.)
        if minmax_normalized and args.R_min is not None and args.R_max is not None:
            replay_buffer.reward = args.R_min + replay_buffer.reward * (args.R_max - args.R_min)

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        # IQL
        "discount": args.discount,
        "tau": args.tau,
        "temperature": args.temperature,
        "expectile": args.expectile,
        "dropout_rate": float(args.dropout_rate) if args.dropout_rate is not None else None,
    }

    # Initialize policy
    if args.policy == 'IQL':
        policy = IQL.IQL(data, **kwargs)
    else:
        raise NotImplementedError

    logger = Logger(args.work_dir, use_tb=True)
    video = VideoRecorder(dir_name=args.video_dir)

    results = []
    for t in trange(int(args.max_timesteps), mininterval=10):
        policy.train(replay_buffer, args.batch_size, logger=logger)
        # Evaluate episode
        if (t + 1) % args.eval_freq == 0:
            # The final antmaze evaluation uses 100 episodes.
            eval_episodes = 100 if t+1 == int(args.max_timesteps) and 'antmaze' in args.env else args.eval_episodes
            # Policies are always scored with the environment's true reward, for
            # both oracle and designed-reward runs. eval_policy_for_cal can be used
            # instead to also log the designed reward's return.
            d4rl_score = eval_policy(args, t+1, video, logger, policy, args.env,
                                     args.seed, mean, std, eval_episodes=eval_episodes)
            results.append(d4rl_score)
            if args.save_model:
                policy.save(args.model_dir)

    results_all = {
        # No evaluation runs when max_timesteps < eval_freq.
        "last_d4rl_score": results[-1] if results else None,
        "results": results,
    }
    with open(os.path.join(args.work_dir, "results.json"), "w", encoding="utf-8") as f:
        json.dump(results_all, f, ensure_ascii=False, indent=4)

    # calculate time
    end_time = time.time()
    elapsed_time_str = str(datetime.timedelta(seconds=int(end_time - start_time)))
    print(f"Elapsed Time: {elapsed_time_str}")
    print("=" * 10 + "finished" + "=" * 10)
    # Restore original stdout
    sys.stdout = file_logger.stdout
    sys.stderr = file_logger.stderr
    file_logger.close()